{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "01d7fb38",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "SBERT_MODEL = \"all-MiniLM-L6-v2\"\n",
    "from collections import Counter\n",
    "import nltk\n",
    "import re\n",
    "from nltk import sent_tokenize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a7c68613",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration jjmachan--NSFW-questions-inter-cleaned_df-90257bbb92c45b4c\n",
      "Found cached dataset parquet (/home/shahul/.cache/huggingface/datasets/jjmachan___parquet/jjmachan--NSFW-questions-inter-cleaned_df-90257bbb92c45b4c/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n",
      "Using custom data configuration allenai--prosocial-dialog-ebbad39ca08b6d44\n",
      "Found cached dataset json (/home/shahul/.cache/huggingface/datasets/allenai___json/allenai--prosocial-dialog-ebbad39ca08b6d44/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n"
     ]
    }
   ],
   "source": [
    "# Load the train split of the two source datasets: the NSFW questions that will\n",
    "# be labelled, and the prosocial dialogues that supply RoTs and safety labels.\n",
    "nsfw_dataset = load_dataset(\"jjmachan/NSFW-questions-inter-cleaned_df\", split=\"train\")\n",
    "pro_social_dataset = load_dataset(\"allenai/prosocial-dialog\", split=\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "59919b39",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['title', 'subreddit', 'post_id', 'score', 'link_flair_text', 'is_self', 'over_18', 'upvote_ratio', 'is_question', 'C1', 'C2', 'C3', 'C4', 'C5'],\n",
       "    num_rows: 12858\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the dataset summary (feature names and row count).\n",
    "nsfw_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "172c9d25",
   "metadata": {},
   "outputs": [],
   "source": [
    "def match_rot_safetylabels(dataset):\n",
    "    rots = [item[\"rots\"] for item in dataset]\n",
    "    safety_annotations = [item[\"safety_label\"] for item in dataset]\n",
    "    results = {}\n",
    "    for rots, sfty in zip(rots, safety_annotations):\n",
    "        for rot in rots:\n",
    "            if rot not in results.keys():\n",
    "                results[rot] = sfty\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "850a3045",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RoT -> safety label lookup built from the prosocial dialogues.\n",
    "rot_sfty = match_rot_safetylabels(pro_social_dataset)\n",
    "# The dict keys are already unique and list() keeps their insertion order, so RoT\n",
    "# indices stay the same across kernel restarts (a set() round-trip would not\n",
    "# guarantee that for strings).\n",
    "all_rots = list(rot_sfty)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "61ad08d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_vectorizer(model=SBERT_MODEL):\n",
    "    \"\"\"Build the sentence encoder.\n",
    "\n",
    "    Args:\n",
    "        model: value handed to ``SentenceTransformer``; defaults to ``SBERT_MODEL``.\n",
    "\n",
    "    Returns:\n",
    "        The ``SentenceTransformer`` instance.\n",
    "    \"\"\"\n",
    "    vectorizer = SentenceTransformer(model)\n",
    "    return vectorizer\n",
    "\n",
    "\n",
    "def vectorize_text(model, texts):\n",
    "    return model.encode(texts, show_progress_bar=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5aabe078",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the sentence encoder once; later cells reuse this notebook-level\n",
    "# `model` for both the RoTs and the post titles.\n",
    "model = load_vectorizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "763dd04e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "69507de2c38e463e96de8431e8dfcd40",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/3630 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Embed every unique RoT. match_rot_post relies on row i of rot_vector\n",
    "# corresponding to all_rots[i].\n",
    "rot_vector = vectorize_text(model, all_rots)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7ca6ebe1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.spatial as sp\n",
    "from collections import defaultdict\n",
    "from tqdm import tqdm\n",
    "\n",
    "THRESHOLD = 0.65\n",
    "\n",
    "\n",
    "def match_query_rot(q, m):\n",
    "    cosine_sim = 1 - sp.distance.cdist(q, m, \"cosine\")\n",
    "    sim_indices = np.argwhere(cosine_sim >= THRESHOLD)\n",
    "    return sim_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6ba196be",
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 100\n",
    "\n",
    "\n",
    "def match_rot_post(dataset):\n",
    "    \"\"\"Attach the most similar rule of thumb (RoT) to every post that has a match.\n",
    "\n",
    "    Post titles are embedded with the notebook-level ``model`` and compared, in\n",
    "    batches of ``BATCH_SIZE``, against ``rot_vector`` using cosine similarity.\n",
    "    A post is kept when its best-scoring RoT reaches ``THRESHOLD``; that single\n",
    "    RoT is stored, so the result does not depend on the order of ``all_rots``.\n",
    "\n",
    "    Args:\n",
    "        dataset: iterable of rows exposing ``title`` and ``post_id``.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping post_id to ``{\"rots\": [rot], \"safety_label\": label}``, where\n",
    "        ``label`` is looked up in ``rot_sfty``.\n",
    "    \"\"\"\n",
    "    # Read titles and ids in one pass instead of indexing the dataset per match.\n",
    "    posts, post_ids = [], []\n",
    "    for item in dataset:\n",
    "        posts.append(item[\"title\"])\n",
    "        post_ids.append(item[\"post_id\"])\n",
    "    post_vector = vectorize_text(model, posts)\n",
    "\n",
    "    dic = {}\n",
    "    for idx in tqdm(range(0, len(post_vector), BATCH_SIZE)):\n",
    "        batch = post_vector[idx : idx + BATCH_SIZE]\n",
    "        cosine_sim = 1 - sp.distance.cdist(batch, rot_vector, \"cosine\")\n",
    "        # Undefined similarities (NaN, e.g. from a zero-norm vector) count as no match.\n",
    "        cosine_sim = np.nan_to_num(cosine_sim, nan=-1.0)\n",
    "        best_rot_idx = cosine_sim.argmax(axis=1)\n",
    "        best_score = cosine_sim.max(axis=1)\n",
    "        for offset in np.flatnonzero(best_score >= THRESHOLD):\n",
    "            rot = all_rots[int(best_rot_idx[offset])]\n",
    "            dic[post_ids[idx + int(offset)]] = {\"rots\": [rot], \"safety_label\": rot_sfty.get(rot)}\n",
    "    return dic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d6b727c5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "67a45e34f79248139c8127692f8b325b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/402 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████| 129/129 [04:03<00:00,  1.89s/it]\n"
     ]
    }
   ],
   "source": [
    "# Match every post title against the RoT embeddings; result_dict is keyed by post_id.\n",
    "result_dict = match_rot_post(nsfw_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "fc13c132",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Turaround perc 11.214807901695442\n"
     ]
    }
   ],
   "source": [
    "# Matched post_ids as a percentage of the number of rows in nsfw_dataset.\n",
    "print(\"Turaround perc\", len(result_dict) / len(nsfw_dataset) * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b2dfc16e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_stopwords(example):\n",
    "    stopwords = [\"Ladies\", \"Women\", \"Gals\", \"Men\", \"guys\"]\n",
    "    regex = \"\".join([f\"{word}(,)?|\" for word in stopwords])\n",
    "    example[\"title\"] = re.sub(regex, \"\", example[\"title\"], flags=re.IGNORECASE)\n",
    "    return example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "24437242",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_rot_label(example):\n",
    "    \"\"\"Copy the matched RoT list and safety label from ``result_dict`` onto ``example``.\n",
    "\n",
    "    Rows whose ``post_id`` is not a key of the notebook-level ``result_dict`` are\n",
    "    returned unchanged.\n",
    "    \"\"\"\n",
    "    key = example[\"post_id\"]\n",
    "    if key not in result_dict:\n",
    "        return example\n",
    "\n",
    "    matched = result_dict[key]\n",
    "    example[\"rots\"] = matched[\"rots\"]\n",
    "    example[\"safety_label\"] = matched[\"safety_label\"]\n",
    "    return example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "c6140d55",
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_response(example):\n",
    "    \"\"\"Pick one two-sentence, link-free comment at random as the post's response.\n",
    "\n",
    "    Candidates are columns C1..C5 of the first row of the notebook-level\n",
    "    ``comments_df`` whose ``post_id`` equals ``example[\"post_id\"]``. Each value is\n",
    "    converted with ``str`` before filtering.\n",
    "\n",
    "    Args:\n",
    "        example: row dict with a ``post_id``; it is modified in place.\n",
    "\n",
    "    Returns:\n",
    "        The same ``example``. ``example[\"response\"]`` is only assigned when at least\n",
    "        one comment passes the filters; it is left as-is when the post has no row\n",
    "        in ``comments_df`` or no comment qualifies.\n",
    "    \"\"\"\n",
    "    rows = comments_df.loc[comments_df[\"post_id\"] == example[\"post_id\"], [\"C1\", \"C2\", \"C3\", \"C4\", \"C5\"]]\n",
    "    if rows.empty:\n",
    "        # No comments are available for this post; nothing to choose from.\n",
    "        return example\n",
    "\n",
    "    # Raw string: \\s must reach the regex engine, not Python's escape handling.\n",
    "    url_pattern = re.compile(r\"(?P<url>https?://[^\\s]+)\")\n",
    "    candidates = []\n",
    "    for comment in rows.values.tolist()[0]:\n",
    "        comment = str(comment)\n",
    "        # Keep comments of exactly two sentences that contain no link.\n",
    "        if len(sent_tokenize(comment)) == 2 and url_pattern.search(comment) is None:\n",
    "            candidates.append(comment)\n",
    "\n",
    "    if candidates:\n",
    "        example[\"response\"] = np.random.choice(candidates, 1)[0]\n",
    "\n",
    "    return example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "02c91152",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add placeholder columns that the later map() calls fill in per row.\n",
    "new_column = [[]] * len(nsfw_dataset)\n",
    "nsfw_dataset = nsfw_dataset.add_column(\"rots\", new_column)\n",
    "new_column = [None] * len(nsfw_dataset)\n",
    "nsfw_dataset = nsfw_dataset.add_column(\"safety_label\", new_column)\n",
    "# Note: the response placeholder is the string \"None\", not the None object.\n",
    "new_column = [\"None\"] * len(nsfw_dataset)\n",
    "nsfw_dataset = nsfw_dataset.add_column(\"response\", new_column)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d93ef036",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at /home/shahul/.cache/huggingface/datasets/jjmachan___parquet/jjmachan--NSFW-questions-inter-cleaned_df-90257bbb92c45b4c/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-a06a0f349e8c04ee.arrow\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b414e202375e495ba6bf6bd69758a78b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/12858 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5d1338dcf427490cbee6f1b64e376ea1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/13 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Clean the titles, then copy the RoT / safety label onto each matched row.\n",
    "# result_dict was computed earlier from the titles as they were before filter_stopwords.\n",
    "nsfw_dataset = nsfw_dataset.map(filter_stopwords)\n",
    "nsfw_dataset = nsfw_dataset.map(add_rot_label)\n",
    "# safety_label is still None for rows without a match, so this keeps matched rows only.\n",
    "nsfw_dataset = nsfw_dataset.filter(lambda example: example[\"safety_label\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "604d5a67",
   "metadata": {},
   "source": [
    "## Comments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "277490d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): get_comments is not defined or imported in any visible cell of this\n",
    "# notebook, so this cell raises NameError on a fresh kernel. Define or import it first.\n",
    "# select_response expects comments_df to have post_id and C1..C5 columns.\n",
    "post_ids = [item[\"post_id\"] for item in nsfw_dataset]\n",
    "comments_df = get_comments(post_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "9dd8d746",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose a response per post; select_response reads the notebook-level comments_df.\n",
    "nsfw_dataset = nsfw_dataset.map(select_response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "0c02e1f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rename the post title to \"user\" and drop the raw comment and score-related columns.\n",
    "nsfw_dataset = nsfw_dataset.rename_columns({\"title\": \"user\"})\n",
    "nsfw_dataset = nsfw_dataset.remove_columns([\"C1\", \"C2\", \"C3\", \"C4\", \"C5\", \"link_flair_text\", \"score\", \"upvote_ratio\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "3b2e1237",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'user': 'Why are condoms not always effective?',\n",
       " 'subreddit': 'sexeducation',\n",
       " 'post_id': 'le7znt',\n",
       " 'is_self': True,\n",
       " 'over_18': False,\n",
       " 'is_question': True,\n",
       " 'rots': [\"It's good to use condoms to protect yourself during sex\"],\n",
       " 'safety_label': '__needs_caution__',\n",
       " 'response': \"They are absolutely weighing in the percentages of condoms breaking, slipping off, being used incorrectly, etc. \\n\\nIf a condom is used correctly, disposed of correctly, and doesn't break at any point, there's 0% chance of pregnancy as far as I'm aware.\"}"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): no seed is passed to shuffle(); pass seed=... if a reproducible row\n",
    "# order (and a stable sample below) is needed. TODO confirm.\n",
    "nsfw_dataset = nsfw_dataset.shuffle()\n",
    "# Display one row as a spot check.\n",
    "nsfw_dataset[35]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "98f8a47a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7002d455f21849a4a87f9e70efbf58fd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Flattening the indices:   0%|          | 0/2 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Add an episode_done column that is True for every row.\n",
    "new_column = [True] * len(nsfw_dataset)\n",
    "nsfw_dataset = nsfw_dataset.add_column(\"episode_done\", new_column)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "e7b05ce7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bbf37be7281141e3a7b33d3dfd7d7ca4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3d13b7b237124fdcaaa8151a10bdb76d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0e010c78b4f5458980e1ce263b95afb2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "03101809bcc0410a89e5203913d33911",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7a1897223ddf48eaae71cc29b9fad0dd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading metadata:   0%|          | 0.00/651 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Updating downloaded metadata with the new split.\n"
     ]
    }
   ],
   "source": [
    "# Publish the processed dataset to the \"shahules786/prosocial-nsfw\" Hub repository.\n",
    "nsfw_dataset.push_to_hub(\"shahules786/prosocial-nsfw\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b288f950",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
